"""
KorQuAD (Korean QA Dataset for Machine Reading Comprehension)
https://arxiv.org/abs/1909.07005

Machine Reading Comprehension (MRC) is a task that requires a machine to understand natural language and answer questions by reading a document.
It is the core of automatic response technologies such as chatbots and automated customer support systems.
We present the Korean Question Answering Dataset (KorQuAD), a large-scale Korean dataset for the extractive machine reading comprehension task.
It consists of 70,000+ human-generated question-answer pairs on Korean Wikipedia articles.
We release KorQuAD 1.0 and launch a challenge at https://korquad.github.io/ to encourage the development of multilingual natural language processing research.
"""
import datasets
from math import exp
from lm_eval.base import rf, Task
from functools import partial
from packaging import version


_CITATION = """
@article{lim2019korquad1,
  title={Korquad1. 0: Korean qa dataset for machine reading comprehension},
  author={Lim, Seungyoung and Kim, Myungji and Lee, Jooyoul},
  journal={arXiv preprint arXiv:1909.07005},
  year={2019}
"""


def _squad_metric(predictions, references):
    """Score predictions against references using the HF ``squad`` metric.

    :param predictions: sequence of ``{'id', 'prediction_text'}`` dicts.
    :param references: sequence of ``{'id', 'answers'}`` dicts.
    :returns: the dict produced by the metric's ``compute`` call.
    """
    metric = datasets.load_metric("squad")
    scores = metric.compute(predictions=predictions, references=references)
    return scores


def _squad_agg(key, items):
    """Reduce per-document (prediction, reference) pairs to a single score.

    :param key: the field of the SQuAD metric result to return
        (e.g. ``'exact_match'`` or ``'f1'``).
    :param items: iterable of ``(prediction, reference)`` pairs.
    :returns: ``_squad_metric(...)[key]`` computed over all pairs.
    """
    # Transpose the list of pairs into two parallel tuples.
    predictions, references = zip(*items)
    scores = _squad_metric(predictions=predictions, references=references)
    return scores[key]


class Korquad(Task):
    """KorQuAD 1.0 extractive QA task scored with SQuAD exact match and F1.

    The prompt shows the article title, passage and question, using Korean
    field labels. The model's greedy continuation up to the first newline is
    taken as the predicted answer.
    """

    VERSION = 1
    DATASET_PATH = "KETI-AIR/korquad"
    DATASET_NAME = "v1.0"

    # NOTE(review): a version assertion (datasets >= 1.11.0, because HF changed
    # the "squad" metric) was previously disabled here. Confirm whether the
    # installed datasets version still needs to be checked.

    def has_training_docs(self):
        """The dataset provides a training split."""
        return True

    def has_validation_docs(self):
        """The dataset provides a validation split (named ``dev``)."""
        return True

    def has_test_docs(self):
        """No labelled test split is used for this task."""
        return False

    def training_docs(self):
        """Return the ``train`` split."""
        return self.dataset["train"]

    def validation_docs(self):
        """Return the ``dev`` split, used as the validation set."""
        return self.dataset["dev"]

    def doc_to_text(self, doc):
        """Build the prompt for ``doc``.

        The labels are Korean: 제목 = title, 본문 = passage, 질문 = question,
        답 = answer. The prompt ends with the answer label so the model
        continues with its answer.
        """
        return '제목: ' + doc['title'] + '\n\n' + '본문: ' + doc['context'] + '\n\n' + '질문: ' + doc['question'] + '\n\n' + '답:'

    def doc_to_target(self, doc):
        """Return the first gold answer, prefixed with a space to follow ``답:``."""
        answer_list = doc['answers']['text']
        answer = answer_list[0]

        return " " + answer

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        A single greedy-generation request is issued, which stops at the first
        newline.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        continuation = rf.greedy_until(ctx, ['\n'])
        return continuation

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        Each value is a ``(prediction, reference)`` pair. The pairs are
        aggregated across documents by the SQuAD metric in ``aggregation``.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        # ``results`` is expected to hold one entry per request, here a single
        # greedy_until continuation. Unwrap it so that ``prediction_text`` is a
        # string; the squad metric normalizes it with string operations and
        # cannot handle a list. A bare string is accepted unchanged.
        if isinstance(results, (list, tuple)):
            continuation = results[0]
        else:
            continuation = results

        predictions = {
            'id': doc['id'],
            'prediction_text': continuation,
        }

        references = {
            'id': doc['id'],
            'answers': doc['answers'],
        }

        return {
            'exact_match': (predictions, references),  # Exact match (the normalized answer exactly matches the gold answer)
            'f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            'exact_match': partial(_squad_agg, 'exact_match'),  # Exact match (the normalized answer exactly matches the gold answer)
            'f1': partial(_squad_agg, 'f1'),  # The F-score of predicted tokens versus the gold answer
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            'exact_match': True,  # Exact match (the normalized answer exactly matches the gold answer)
            'f1': True,  # The F-score of predicted tokens versus the gold answer
        }
